{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a8ee28d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocess import Pool\n",
    "import itertools\n",
    "import numpy as np\n",
    "\n",
    "def chunks(l, n):\n",
    "    for i in range(0, len(l), n):\n",
    "        yield (l[i: i + n], i // n)\n",
    "\n",
    "def multiprocessing(strings, function, cores=6, returned=True):\n",
    "    df_split = chunks(strings, len(strings) // cores)\n",
    "    pool = Pool(cores)\n",
    "    pooled = pool.map(function, df_split)\n",
    "    pool.close()\n",
    "    pool.join()\n",
    "\n",
    "    if returned:\n",
    "        return list(itertools.chain(*pooled))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fb31c7c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    dev: Dataset({\n",
      "        features: ['audio', 'timestamps_start', 'timestamps_end', 'speakers'],\n",
      "        num_rows: 216\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['audio', 'timestamps_start', 'timestamps_end', 'speakers'],\n",
      "        num_rows: 232\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "ds = load_dataset(\"diarizers-community/voxconverse\")\n",
    "\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fb88cb25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "import soundfile as sf\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "def convert_rttm(chunk, filename = 'audio'):\n",
    "    rttm = []\n",
    "    for start, end, speaker in chunk:\n",
    "        duration = end - start\n",
    "        rttm.append(f\"SPEAKER {filename} 1 {start:.4f} {duration:.4f} <NA> <NA> <NA> <NA> {speaker}\")\n",
    "    return '\\n'.join(rttm)\n",
    "\n",
    "def convert_textgrid(segments):\n",
    "    tiers = defaultdict(list)\n",
    "    for start, end, speaker in segments:\n",
    "        tiers[speaker].append((start, end))\n",
    "\n",
    "    min_time = min(start for start, _, _ in segments)\n",
    "    max_time = max(end for _, end, _ in segments)\n",
    "\n",
    "    textgrid = []\n",
    "    textgrid.append(\"File type = \\\"ooTextFile\\\"\")\n",
    "    textgrid.append(\"Object class = \\\"TextGrid\\\"\")\n",
    "    textgrid.append(\"\")\n",
    "    textgrid.append(f\"xmin = {min_time:.2f}\")\n",
    "    textgrid.append(f\"xmax = {max_time:.2f}\")\n",
    "    textgrid.append(\"tiers? <exists>\")\n",
    "    textgrid.append(f\"size = {len(tiers)}\")\n",
    "    textgrid.append(\"item []:\")\n",
    "\n",
    "    for i, (speaker, intervals) in enumerate(tiers.items(), start=1):\n",
    "        textgrid.append(f\"    item [{i}]:\")\n",
    "        textgrid.append(\"        class = \\\"IntervalTier\\\"\")\n",
    "        textgrid.append(f\"        name = \\\"{speaker}\\\"\")\n",
    "        textgrid.append(f\"        xmin = {min_time:.2f}\")\n",
    "        textgrid.append(f\"        xmax = {max_time:.2f}\")\n",
    "        textgrid.append(f\"        intervals: size = {len(intervals)}\")\n",
    "\n",
    "        for j, (start, end) in enumerate(intervals, start=1):\n",
    "            textgrid.append(f\"        intervals [{j}]:\")\n",
    "            textgrid.append(f\"            xmin = {start:.2f}\")\n",
    "            textgrid.append(f\"            xmax = {end:.2f}\")\n",
    "            textgrid.append(f\"            text = \\\"{speaker}\\\"\")\n",
    "            \n",
    "    return '\\n'.join(textgrid)\n",
    "\n",
    "timestamps = [i * 0.02 for i in range(1500 + 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2755f299",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !rm -rf voxconverse\n",
    "# !mkdir voxconverse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3987a18d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "def loop(indices):\n",
    "    indices, _ = indices\n",
    "    ds = load_dataset(\"diarizers-community/voxconverse\")\n",
    "    data = []\n",
    "    for k, key in tqdm(indices):\n",
    "        row = ds[key][k]\n",
    "        audio = row['audio']['array']\n",
    "        chunks, temp = [], []\n",
    "        argsort = np.argsort(row['timestamps_start'])\n",
    "        timestamps_start = [row['timestamps_start'][i] for i in argsort]\n",
    "        timestamps_end = [row['timestamps_end'][i] for i in argsort]\n",
    "        speakers = [row['speakers'][i] for i in argsort]\n",
    "        start = timestamps_start[0]\n",
    "        max_len = 30\n",
    "        for i in range(len(timestamps_start)):\n",
    "            l = timestamps_end[i] - start\n",
    "            if l >= max_len:\n",
    "                chunks.append(temp)\n",
    "                temp = [[timestamps_start[i], timestamps_end[i], speakers[i]]]\n",
    "                start = timestamps_start[i]\n",
    "                continue\n",
    "            else:\n",
    "                temp.append([timestamps_start[i], timestamps_end[i], speakers[i]])\n",
    "\n",
    "        if len(temp):\n",
    "            chunks.append(temp)\n",
    "\n",
    "        for no, chunk in enumerate(chunks):\n",
    "            speakers = []\n",
    "            for i in range(len(chunk)):\n",
    "                if chunk[i][-1] not in speakers:\n",
    "                    speakers.append(chunk[i][-1])\n",
    "            \n",
    "            try:          \n",
    "                start_time = chunk[0][0]\n",
    "                end_time = max([c[1] for c in chunk])\n",
    "            except Exception as e:\n",
    "                continue\n",
    "                \n",
    "            if round(end_time - start_time, 2) > max_len:\n",
    "                continue\n",
    "            \n",
    "            y = audio[int(16000 * start_time): int(16000 * end_time)]\n",
    "            audio_filename = f'voxconverse/{key}-{k}-{no}.mp3'\n",
    "            if not os.path.exists(audio_filename):\n",
    "                sf.write(audio_filename, y, 16000)\n",
    "            \n",
    "            ts = []\n",
    "            for i in range(len(chunk)):\n",
    "                index = speakers.index(chunk[i][-1])\n",
    "                start = min(timestamps, key=lambda t: abs(t - (chunk[i][0] - start_time)))\n",
    "                end = min(timestamps, key=lambda t: abs(t - (chunk[i][1] - start_time)))\n",
    "                speaker_name = f'speaker {string.ascii_uppercase[index]}'\n",
    "                chunk[i][-1] = speaker_name\n",
    "                chunk[i][0] = start\n",
    "                chunk[i][1] = end\n",
    "                t = f\"<|{start:.2f}|> {speaker_name}<|{end:.2f}|>\"\n",
    "                ts.append(t)\n",
    "                \n",
    "            ts = ''.join(ts)\n",
    "            rttm = convert_rttm(chunk)\n",
    "            textgrid = convert_textgrid(chunk)\n",
    "            \n",
    "            data.append({\n",
    "                'question': 'diarize the audio using whisper format',\n",
    "                'answer': ts,\n",
    "                'audio_filename': audio_filename,\n",
    "            })\n",
    "            data.append({\n",
    "                'question': 'diarize the audio using rttm format',\n",
    "                'answer': rttm,\n",
    "                'audio_filename': audio_filename,\n",
    "            })\n",
    "            data.append({\n",
    "                'question': 'diarize the audio using textgrid format',\n",
    "                'answer': textgrid,\n",
    "                'audio_filename': audio_filename,\n",
    "            })\n",
    "            \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2f1d4a50",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.36it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.08it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.16it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.38it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.18it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.61it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.41it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.07it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.02it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.31it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.63it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.70it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.42it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.56it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.88it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  7.22it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 28.57it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 45.49it/s]\n"
     ]
    }
   ],
   "source": [
    "indices = list(range(len(ds['dev'])))\n",
    "indices = [(i, 'dev') for i in indices]\n",
    "prepared_validation = multiprocessing(indices, loop, cores = min(len(indices), 20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "48aabe86",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  5.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.75it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.83it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  5.30it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.81it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.88it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  5.12it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.75it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  5.33it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  5.12it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.59it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  5.79it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  5.00it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.15it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.99it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:01<00:00,  5.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  5.36it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:02<00:00,  4.80it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.35it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 13.71it/s]\n"
     ]
    }
   ],
   "source": [
    "indices = list(range(len(ds['test'])))\n",
    "indices = [(i, 'test') for i in indices]\n",
    "prepared_test = multiprocessing(indices, loop, cores = min(len(indices), 20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d78101b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6564, 13959)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(prepared_validation), len(prepared_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e1350611",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'diarize the audio using rttm format',\n",
       " 'answer': 'SPEAKER audio 1 0.0000 9.9200 <NA> <NA> <NA> <NA> speaker A\\nSPEAKER audio 1 9.9600 1.1200 <NA> <NA> <NA> <NA> speaker B\\nSPEAKER audio 1 11.7200 12.1200 <NA> <NA> <NA> <NA> speaker B',\n",
       " 'audio_filename': 'voxconverse/dev-215-4.mp3'}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prepared_validation[-2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8c024481",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "pd.DataFrame(prepared_validation).to_parquet('voxconverse-validation.parquet')\n",
    "pd.DataFrame(prepared_test).to_parquet('voxconverse-test.parquet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "59e4627e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uploading files using Xet Storage..\n",
      "Uploading...: 100%|██████████████████████████| 424k/424k [00:06<00:00, 68.2kB/s]\n",
      "https://huggingface.co/datasets/mesolitica/Speaker-Diarization-Instructions/blob/main//data/voxconverse_validation-00000-of-00001.parquet\n"
     ]
    }
   ],
   "source": [
    "!huggingface-cli upload mesolitica/Speaker-Diarization-Instructions \\\n",
    "voxconverse-validation.parquet /data/voxconverse_validation-00000-of-00001.parquet \\\n",
    "--repo-type=dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4c52b8c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uploading files using Xet Storage..\n",
      "Uploading...: 100%|█████████████████████████| 1.02M/1.02M [00:05<00:00, 203kB/s]\n",
      "https://huggingface.co/datasets/mesolitica/Speaker-Diarization-Instructions/blob/main//data/voxconverse_test-00000-of-00001.parquet\n"
     ]
    }
   ],
   "source": [
    "!huggingface-cli upload mesolitica/Speaker-Diarization-Instructions \\\n",
    "voxconverse-test.parquet /data/voxconverse_test-00000-of-00001.parquet \\\n",
    "--repo-type=dataset"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
